package com.atg.ai_agent.rag;


/*
author: atg
time: 2025/10/4 11:22
*/

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.stringtemplate.v4.ST;

/**
 * 查询重写 RewriteQueryTransformer
 */
public class ReWriteQueryTransformer {
    private final QueryTransformer queryTransformer;


    /**
     * 构造函数 注入
     * @param dashscopeChatModel
     */
    public ReWriteQueryTransformer(ChatModel dashscopeChatModel) {
        ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel);
        // 创建查询重写转换器
        queryTransformer = RewriteQueryTransformer.builder()
                .chatClientBuilder(builder)
                .build();
    }

    // 重写查询
    public String rewriteQuery(String prompt) {
        Query query = new Query(prompt);
        Query transform = queryTransformer.transform(query);
        return transform.text();

    }


}


